{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/atari_util.py\n",
        "    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week08_pomdp/env_pool.py\n",
        "\n",
        "    # Quote the requirement so the shell never tries to glob-expand the brackets;\n",
        "    # %pip (unlike !pip) installs into the environment of the running kernel.\n",
        "    %pip install -q \"gymnasium[atari,accept-rom-license]\"\n",
        "\n",
        "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
        "    !touch .setup_complete\n",
        "\n",
        "# If you are running on a server, launch xvfb to record game videos\n",
        "# Please make sure you have xvfb installed\n",
        "# An unset or empty DISPLAY both mean there is no X server to render into.\n",
        "if not os.environ.get(\"DISPLAY\"):\n",
        "    !bash ../xvfb start\n",
        "    os.environ['DISPLAY'] = ':1'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from IPython.core import display\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Kung-Fu, recurrent style\n",
        "\n",
        "In this notebook we'll once again train an RL agent for Atari [KungFuMaster](https://gymnasium.farama.org/environments/atari/kung_fu_master/), this time using recurrent neural networks.\n",
        "\n",
        "![img](https://upload.wikimedia.org/wikipedia/en/6/66/Kung_fu_master_mame.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
            "Observation shape: (1, 42, 42)\n",
            "Num actions: 14\n",
            "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n"
          ]
        }
      ],
      "source": [
        "import gymnasium as gym\n",
        "from atari_util import PreprocessAtari\n",
        "\n",
        "\n",
        "def make_env():\n",
        "    \"\"\"Build KungFuMaster wrapped by PreprocessAtari: cropped, resized to 42x42, color=False, a single frame.\"\"\"\n",
        "    raw_env = gym.make(\"KungFuMasterDeterministic-v0\", render_mode=\"rgb_array\")\n",
        "    return PreprocessAtari(\n",
        "        raw_env,\n",
        "        height=42,\n",
        "        width=42,\n",
        "        crop=lambda img: img[60:-30, 15:],\n",
        "        color=False,\n",
        "        n_frames=1,\n",
        "    )\n",
        "\n",
        "\n",
        "env = make_env()\n",
        "\n",
        "obs_shape = env.observation_space.shape\n",
        "n_actions = env.action_space.n\n",
        "\n",
        "print(\"Observation shape:\", obs_shape)\n",
        "print(\"Num actions:\", n_actions)\n",
        "print(\"Action names:\", env.unwrapped.get_action_meanings())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/home/jheuristic/anaconda3/lib/python3.6/site-packages/scipy/misc/pilutil.py:482: FutureWarning: Conversion of the second argument of issubdtype from `int` to `np.signedinteger` is deprecated. In future, it will be treated as `np.int64 == np.dtype(int).type`.\n",
            "  if issubdtype(ts, int):\n",
            "/home/jheuristic/anaconda3/lib/python3.6/site-packages/scipy/misc/pilutil.py:485: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
            "  elif issubdtype(type(size), float):\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAANEAAAEICAYAAADBfBG8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFmVJREFUeJzt3XvUHHV9x/H3hyBoASHcEgi3wAGO4CVGxFTKRbyFVAXaqsFWUWkJlVA80FMIKFLUAirQKBUImnIRQSqi1BNQCnhpEeRiCJcIJIAQckMIBAVpE7/9Y2Zhstl9nnl2dp+Z2f28ztmzszOzu99J5ru/3/xmnu8oIjCzzm1QdgBmdeckMivISWRWkJPIrCAnkVlBTiKzgpxEfUjSTpJ+J2lM2bEMAidRAZKmS7pd0u8lrUynPyVJZcYVEY9HxKYRsbbMOAaFk6hDkk4EZgNfBsYD44BjgP2AjUoMzUZbRPgxwgewOfB74C+HWe/PgV8Bq4EngNMzy3YBAvhEumwVSRK+FVgAPAuc3/R5nwQWpuv+CNi5zfc2PnvD9PVPgC8AtwK/A/4T2Aq4Io3tDmCXzPtnpzGtBu4C9s8sew1waRrDQuCfgCWZ5dsD1wBPAY8C/1D2/1fP94eyA6jjA5gKrGnspEOsdxDwBpIW/43ACuCwdFljR78QeDXwHuAPwPeBbYEJwErgwHT9w4BFwOuADYHPALe2+d5WSbQI2C39AXgAeAh4V/pZlwH/nnn/36RJtiFwIrAceHW67Czgp8BYYIc04ZekyzZIk+40ktZ4V+AR4L1l/5/1dH8oO4A6PtKdbHnTvFvT1uNF4IA27/tX4Lx0urGjT8gsfxr4cOb1NcCn0+nrgaMyyzYAXqBFa9QmiU7NLD8HuD7z+v3A/CG2dxXwpnR6naQA/jaTRG8DHm9676xsgvbjw8dEnXka2FrSho0ZEfH2iNgiXbYBgKS3SbpF0lOSniPprm3d9FkrMtMvtni9aTq9MzBb0rOSngWeAUTSYuWR93uQdKKkhZKeS79r80zc25N09Rqy0zsD2zdiTN97CsnxYt9yEnXmF8BLwKHDrPdt4Dpgx4jYnKTr1unI3RPAjIjYIvN4TUTc2uHntSRpf+Ak4EPA2PSH4TleiXsZSTeuYcemGB9tinGziJjWzRirxknUgYh4Fvhn4OuS/krSppI2kDQJ2CSz6mbAMxHxB0n7Ah8p8LUXArMk7Q0gaXNJHyzwee1sRnK89xSwoaTTgNdmll+dxjFW0gRgZmbZL4HVkk6S9BpJYyS9XtJbexBnZTiJOhQRXwJOIBmdWknSPbqI5Fe80Tp8CjhD0vMkB9tXF/i+a4GzgaskrQbuAw7peAPa+xHJ8ddDwG9IBjuyXbYzgCUkI2//BXyXpFUmkvNS7wcmpct/C3yDpDvYt5Qe/Jl1RNLfA9Mj4sCyYymLWyIbEUnbSdov7b7uSTIEfm3ZcZVpw+FXMVvHRiTd1okkQ/pXAV8vNaKS9aw7J2kqyZnvMcA3IuKsnnyRWcl6kkTp1cMPAe8mOQi9AzgiIh7o+peZlaxX3bl9gUUR8QiApKtIzqm0TCJJHt2wKvptRGwz3Eq9GliYwLrDoktoOrMu6WhJd0q6s0cxmBX1mzwr9aolanVWfp3WJiLmAHPALZHVW69aoiWseznIDsDSHn2XWal6lUR3ALtLmihpI2A6yTVkZn2nJ925iFgjaSbJJSRjgLkRcX8vvsusbJW47MfHRFZRd0XEPsOt5Mt+zAqqxWU/xx9/fNkh2ACaPXt2rvXcEpkVVIuWaLTMmDEDgIsuuqjtsqzm9ZrXGelyqye3RKlWSdJq2UUXXfTyzp+dn03ATpZbfTmJUm4VrFNOohyyCTZjxowhu3btllv/chKZFeSBhZyGGyRoXset0eBwS5RDnoRw0gyuWlz2MxonW0c6PJ1nHQ9x19vs2bNzXfbjJDJrI28SuTtnVpCTyKwgj85VyNhZY9ebt+rMVSVEYiPhlqgiGgm06sxVLz+y8626nERmBXWcRJJ2TG9gtVDS/ZKOT+efLu
lJSfPTR1/fm8asyDHRGuDEiLhb0mbAXZJuTJedFxFfKR6eWfV1nEQRsYzkrmlExPOSFpL/1odmfaMrx0SSdgHeDNyezpopaYGkuZJaHhm7Auq6sgMJjUd2vlVX4SFuSZvyyl2uV0u6APg8ScXTz5PcqfqTze9zBdT1OWHqqVBLJOlVJAl0RUR8DyAiVkTE2oj4I3AxSXF7s75VZHROwDeBhRFxbmb+dpnVDie5t6hZ3yrSndsP+Chwr6T56bxTgCPSu2gH8BjgvxGwvlZkdO6/aX33h3mdh2NV5D/hGNrAXjt374NHrPP6DXteOaLl3fiMPN9RthkzZrSsMeFEeoUv+7EhOVmG5ySy3IYqbjnInESWm4tOtuYksiE5YYbnGgs2rEEdnctbY2FgR+csv0FJmk65O2dWkJPIrCAnkVlBA3NM1HyPoVZn4lstzz5nNc9rfNasWQ/3ahO64swzdy87hL4zUC3RcAfIeQ6gszfpyvse628DlUTDnfNoXt5q/Tzr2GAZqCRqbkVaLW+ebl6/1fvdGg22gUqiZp3c1a75Pa2Ol2yw+IoFszZG7YoFSY8BzwNrgTURsY+kLYHvALuQ/HXrhyLCVTisL3WrO/eOiJiUydqTgZsiYnfgpvS1WV/q1XmiQ4GD0ulLgZ8AJ/Xou0ZkJOeDWs1v9Z6sQ37+89HZkA5dv//+ZYfQd7qRRAH8OD2uuSitJzcurZBKRCyTtG0Xvqdrit4m0iyrG925/SJiMnAIcKykA/K8qcwKqCM9X9TpOjYYCidRRCxNn1cC15IUa1zRqD+XPq9s8b45EbFPntGPbhvplQvtXvv8kEHxCqibpHeEQNImwHtIijVeBxyZrnYk8IMi39Ntrc71DLXcbCiFzhNJ2pWk9YHk+OrbEfFFSVsBVwM7AY8DH4yIZ4b4HJ8nssoZlfNEEfEI8KYW858G3lnks83qohZXLJiVpH9qLEz+wuSyQ7ABdPdn7s61Xi2SaNsdKnWayWwdtUiiDa4e6IvNreJqkUTzd5g//EpmJalFEo3faXzZIdgAWsrSXOu5n2RWUC1aIg8sWJX5PJFZe7nOE7k7Z1aQk8isoFocE90w2Vcs2Oibene+KxbcEpkV5CQyK8hJZFZQLY6JJs3zFQtWgpy7nVsis4I6bokk7UlS5bRhV+A0YAvg74Cn0vmnRMS8jiMEPvLx04ZdZ9aJxwFw5jlfK/JVhTiGfosh327bcRJFxIPAJABJY4AnSeotfAI4LyK+0ulnd2LtSWuTiRKvEHIMgxlDt46J3gksjojfSOrSR47MmLPHJBPnlPL1jmGAY+hWEk0Hrsy8ninpY8CdwImjUcx+0H79HEN1Yig8sCBpI+ADwH+ksy4AdiPp6i2jzW9Btyugjjl7zCu/PiVxDIMZQzdaokOAuyNiBUDjGUDSxcAPW70prdk9J12v8FXcg/br5xiqE0M3kugIMl05Sds1itkDh5NURO25QeuHO4bqxFAoiST9CfBuIFtz90uSJpHcLeKxpmU9M2i/fo6hOjEUrYD6ArBV07yPFoqoQ4P26+cYqhNDLS77yWPQfv0cQ3Vi6JskGrRfP8dQnRj6JokG7dfPMVQnhr5JokH79XMM1Ymhb5Jo0H79HEN1YuibJBq0Xz/HUJ0Y+iaJBu3XzzFUJ4ZaFG9cvnzaaIVi9rLx4+e5eKPZaKhFd+6Wyb61ilWXWyKzgpxEZgU5icwKqsUx0TvunlR2CDaIxvtOeWajohYtUZ66c2bdl6/unFsis4JyJZGkuZJWSrovM29LSTdKejh9HpvOl6SvSlokaYEk31zI+lrelugSYGrTvJOBmyJid+Cm9DUk1X92Tx9Hk5TQMutbuZIoIn4GPNM0+1Dg0nT6UuCwzPzLInEbsIWk7boRrFkVFTkmGtcojZU+N66XnQA8kVlvSTpvHd0u3mhWll6MzrUqxr3eVdrdLt5oVpYiLdGKRjctfV6Zzl8C7JhZbwcg31krsxoqkkTXAUem00
cCP8jM/1g6SjcFeC5TEdWs7+Tqzkm6EjgI2FrSEuBzwFnA1ZKOAh4HPpiuPg+YBiwCXiC5X5FZ38qVRBFxRJtF72yxbgDHFgnKrE58xYJZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQcMmUZvqp1+W9Ou0wum1krZI5+8i6UVJ89PHhb0M3qwK8rREl7B+9dMbgddHxBuBh4BZmWWLI2JS+jimO2GaVdewSdSq+mlE/Dgi1qQvbyMpi2U2kLpxTPRJ4PrM64mSfiXpp5L2b/cmV0C1flGoAqqkU4E1wBXprGXAThHxtKS3AN+XtHdErG5+bzcroN58w5SXpw+eeluRj6p1DEOpenx11nFLJOlI4H3AX6dlsoiIlyLi6XT6LmAxsEc3Am0nu3OUpQoxjETd4q26jpJI0lTgJOADEfFCZv42ksak07uS3F7lkW4EmlcVdpAqxJBVtXj6zbDduTbVT2cBGwM3SgK4LR2JOwA4Q9IaYC1wTEQ035KlJxpdlDJ3mCrE0E6VY6u7YZOoTfXTb7ZZ9xrgmqJBdaKxc5TZ369CDK0cPPU2J08P1eLGx0M5eOptfO3tZ7z8+rhbBzOG4Sz41rSXpz/9Ld9Iupt82Y9ZQX2RRMfdeto6z4Maw1AarY9boe6rfXcOYI97FnAc5e4cZcVw/rmvBWDmCeudimux3lc4P72X+3DrW361b4n2uGfBOs+DFEMjgZqnh1ovz/o2MrVPoqwyE6lKMTScf+5rnSyjoLbduarsrGXG0eiSNRJluIRpXt+6oy9aoofe9MayQyg1huzxzcwTVrd83ZxAPibqntq2RNZacyvjVqf3+qIlstYtS3OrNNS61rnaJ9Ggd+WympOjMbCQTSYnUPfVPomyB/Zl7cxViGEo2WSy7qt9Etm6nCijr/YDC1X45a9CDFl77bXXeleS33zDlMpdXd4v3BKZFVTbJFo790DWzj1wnddlxVF2DMNxK9Rbte/OAex2/NiyQ6hEDA0HT71t3fND5z7gY6Ue6rQC6umSnsxUOp2WWTZL0iJJD0p6b68Cb6UKO3IVYmjmBOqtTiugApyXqXQ6D0DSXsB0YO/0PV9vFC7ptsWzV7F49ip2O34si2ev6sVX5I6j7BisXHlqLPxM0i45P+9Q4KqIeAl4VNIiYF/gFx1HmEMVduIqxGDlKDKwMDMtaD9XUqMPMwF4IrPOknTeerpVAbWx45bZjapCDFaeTpPoAmA3YBJJ1dNz0vlqsW7L6qYRMSci9omIfTqMYT1V2ImrEIMvOh1dHSVRRKyIiLUR8UfgYpIuGyQtz46ZVXcAlhYL0YrwoELvdVoBdbvMy8OBxsjddcB0SRtLmkhSAfWXxUIcWhV++asQg5Wn0wqoB0maRNJVewyYARAR90u6GniApND9sRGxtjehWyvuyo2+rlZATdf/IvDFIkHlUZVf/6rEYeWp7WU/rVRhiLkKMdjoUnpXlHKDGOb+RENd97Xf8icB+J/xLUfSR0UVYsiqak3wurn5hil35Rk9rsW1cydMbn/r19vnfRZIduS3Tfv8aIVUuRiybr4heR7q382G1/h3HE7tu3NV2GmrEEMr7/uX+WWHMBBq0Z0zK0n/dOd+eMqkskOwAZS3Ja99d86sbE4is4KcRGYFeWDBrD0PLJgV4YEFs1FSi+7c8uXThlps1hPjx8/rn+7cLZN95t2qy905s4KcRGYFOYnMCuq0Aup3MtVPH5M0P52/i6QXM8su7GXwZlWQZ2DhEuB84LLGjIj4cGNa0jnAc5n1F0dEV0/svONunyeyEozPV6iqUAVUSQI+BBw8gtBGbPz4eb38eLNCig5x7w+siIiHM/MmSvoVsBr4TET8vNUbJR0NHJ3nS67cfvuCYZqN3BFLu9QSDfc9wJWZ18uAnSLiaUlvAb4vae+IWK+CYETMAeaAr52zeus4iSRtCP
wF8JbGvLSQ/Uvp9F2SFgN7AIXqbeeVPXZqnKBtNc8xlB/DaMTR7vu6/W9RZIj7XcCvI2JJY4akbRq3UpG0K0kF1EeKhTgyrf5RRvuKB8dQrRh6HUeeIe4rSW6NsqekJZKOShdNZ92uHMABwAJJ9wDfBY6JiGe6Fq1ZBXVaAZWI+HiLedcA1xQPy6w+fMWCWUF9mUTZ/m5ZV4A7hurE0Os4avGnECNRhasbHMNgxVCLP8rzyVYrwxFLl+b6o7xaJJFZSfrnL1uT619H5vI//WcAPvqLz3U7GMdQwxg6i2NmrrX6cmDBbDQ5icwKchKZFVSLY6Lx229Vynu7xTFUJwbIH8fyfH8J4ZbIrKhatETbjB/ZHbrPPfuznHDS5QBcfulnOeGk0b+TnWOoTgydxjGwLdEVl5zFuHGbvPx63LhNuOKSsxzDAMfQ6zjq0RJtu8WI39P8j9TJZxTlGKoTQy/jqMUVCyO9lfy3Lzljndcf+fhpIw+qIMdQnRg6jePmG6b0z2U/I00is27Im0R9d0xkNtry/Hn4jpJukbRQ0v2Sjk/nbynpRkkPp89j0/mS9FVJiyQtkDS51xthVqY8LdEa4MSIeB0wBThW0l7AycBNEbE7cFP6GuAQkgIlu5PUlbug61GbVciwSRQRyyLi7nT6eWAhMAE4FLg0Xe1S4LB0+lDgskjcBmwhabuuR25WESMa4k7LCb8ZuB0YFxHLIEk0Sdumq00Ansi8bUk6b1nTZ+WugHrzDVNGEqbZqMqdRJI2Jank8+mIWJ2U4W69aot5642+uQKq9Ytco3OSXkWSQFdExPfS2Ssa3bT0eWU6fwmwY+btOwA5L6Awq588o3MCvgksjIhzM4uuA45Mp48EfpCZ/7F0lG4K8Fyj22fWlyJiyAfwZyTdsQXA/PQxDdiKZFTu4fR5y3R9Af8GLAbuBfbJ8R3hhx8VfNw53L4bEfW4YsGsJL5iwWw0OInMCnISmRXkJDIrqCp/lPdb4Pfpc7/Ymv7Znn7aFsi/PTvn+bBKjM4BSLozz0hIXfTT9vTTtkD3t8fdObOCnERmBVUpieaUHUCX9dP29NO2QJe3pzLHRGZ1VaWWyKyWnERmBZWeRJKmSnowLWxy8vDvqB5Jj0m6V9J8SXem81oWcqkiSXMlrZR0X2ZebQvRtNme0yU9mf4fzZc0LbNsVro9D0p674i/MM+l3r16AGNI/mRiV2Aj4B5grzJj6nA7HgO2bpr3JeDkdPpk4Oyy4xwi/gOAycB9w8VP8mcw15P8ycsU4Pay48+5PacD/9hi3b3S/W5jYGK6P44ZyfeV3RLtCyyKiEci4n+Bq0gKnfSDdoVcKicifgY80zS7toVo2mxPO4cCV0XESxHxKLCIZL/MrewkalfUpG4C+LGku9ICLNBUyAXYtu27q6ld/HX+P5uZdkHnZrrXhben7CTKVdSkBvaLiMkkNfeOlXRA2QH1UF3/zy4AdgMmkVSeOiedX3h7yk6ivihqEhFL0+eVwLUk3YF2hVzqoq8K0UTEiohYGxF/BC7mlS5b4e0pO4nuAHaXNFHSRsB0kkIntSFpE0mbNaaB9wD30b6QS130VSGapuO2w0n+jyDZnumSNpY0kaRy7y9H9OEVGEmZBjxEMipyatnxdBD/riSjO/cA9ze2gTaFXKr4AK4k6eL8H8kv81Ht4qeDQjQV2Z7L03gXpImzXWb9U9PteRA4ZKTf58t+zAoquztnVntOIrOCnERmBTmJzApyEpkV5CQyK8hJZFbQ/wPTMFRqoBLrRQAAAABJRU5ErkJggg==",
            "text/plain": [
              "<matplotlib.figure.Figure at 0x7fc9f433a8d0>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAFz5JREFUeJzt3X20HHV9x/H35z5xQxIeEkIMJAUfooItpi1G6sORolhEFGzViihpy7HtsfRYH9qqfcJWrZ6K2HP06EFFUquAjzVVasmJIIVaHsSIQagBBBMTEhACuXm6T9/+MXPL3jtzc/fe3Z3dze/zOmfP3f3N7M539u53Z+a3M7+vIgIzS09PuwMws/Zw8pslyslvlignv1minPxmiXLymyXKyZ8wSSdKCkl97Y5lNiRdIOm6dsfR7Zz8TSTpBkmPSTqswmWGpGdUtbyqlX1BRcQXIuLl7YzrUODkbxJJJwIvBgJ4dVuD6SDK+HPWgfxPaZ4Lgf8BrgTW1E6QtFjSv0t6QtJtkt4v6aaa6c+WtF7So5L+V9Lra6ZdKekTkr4labekWyQ9PZ92Yz7bDyUNSfrdqUFJ6pH015IelLRT0r9IOnLKbH8gaZuk7ZLeWfPc1ZJuz+PeIemjNdNOk/TfknZJ+qGk02um3SDpA5JuBvYC75V0+5S43i5pXX7/lZJ+kC9ni6RLamadWMdd+Tr+hqTfm/L+vSB/Xx/P/75gSiz/IOnm/P27TtIxU9+nJEWEb024AfcCbwV+HRgBltZMuzq/HQ6cDGwBbsqnzc8f/z7QB/wa8AjwnHz6lcCjwOp8+heAq2teO4BnHCSuP8hjexqwAPga8Pl82on586/K4/gV4GHgZfn07wFvzu8vAE7L7x8P/AI4m2wDcmb+eEk+/QbgZ8Bz8piPBHYDK2viug14Q37/9HzZPcApwA7gvCkx9tU89/dq3r9FwGPAm/NlnZ8/XlwTy33AM4F5+eMPtfvz0gk3b/mbQNKLgBOAL0XE98k+bG/Mp/UCvwP8XUTsjYgfA2trnn4O8EBEfC4iRiPiDuCrwGtr5vlaRNwaEaNkyb9qFuFdAHw0Iu6PiCHgPcAbpnTyvS8i9kTEj4DPkSUQZF9iz5B0TEQMRcT/5O1vAq6NiGsjYjwi1gO3k30ZTLgyIu7K1+lx4BsTrytpJfBsYB1ARNwQET/KX+tOsi+jl9S5fq8ENkfE5/NlXQXcA7yqZp7PRcRPImIf8CVm9/4dspz8zbEGuC4iHskff5End/2XkG2RttTMX3v/BOD5+e7zLkm7yBL2KTXzPFRzfy/ZVrhexwEP1jx+MI9n6TTxPJg/B+Aisi3mPfnu9Dk1Mb9uSswvApZN85qQvScTXypvBP4tIvYCSHq+pOslPSzpceCPgXp3zaeu38Q6HF/zuJH375DVVT/xdCJJ84DXA72SJj5khwFHSXousAkYBZYDP8mnr6h5iS3AdyPizBaFuI0sWSf8Uh7PjjymiXjuqZm+DSAiNgPn5x12vw18RdLiPObPR8RbDrLcqZeLXgccI2kV2ZfA22umfRH4OPCKiNgv6WM8mfwzXXY6df0m1uHbMzwved7yN+48YIzsWH5VfjsJ+C/gwogYIzvOvkTS4ZKeTdY5OOGbwDMlvVlSf357nqST6lz+DrLj+elcBbxd0lMlLQA+CFyTH0JM+Js8tueQ9T1cAyDpTZKWRMQ4sCufdwz4V+BVkn5LUq+kQUmnS1rONPLlfQX4J7Lj9PU1kxcCj+aJv5r8kCn3MDB+kHW8luz9e6OkvrzT82Sy99UOwsnfuDVkx5Q/i4iHJm5kW7IL8mPri8k6vR4CPk+WkAcAImI38HLgDWRbsYeAD5PtPdTjEmBtvvv9+pLpV+TLvBH4KbAf+NMp83yXrFNwA/CRiJg4geYs4C5JQ8A/k3XQ7Y+ILcC5wHvJknML8OfM/Hn6IvAy4MtTvnzeCvy9pN3A35IdlwOQHxp8ALg5X8fTal8wIn5B1m/yTrJOx78Azqk5BLNpKO8RtQpJ+jDwlIhYM+PMZi3iLX8F8t/xT8
nOd9Fqso60r7c7LkubO/yqsZBsV/84YCdwKdlPX2Zt491+s0R5t98sUQ3t9ks6i6wXuBf4TER86GDz9w/Mj8HBoxtZpJkdxP79jzEyvEf1zDvn5M9PW/0E2XndW4HbJK3LT18tNTh4NKeuvniuizSzGdx+68frnreR3f7VwL35OePDZBeunNvA65lZhRpJ/uOZfP72ViafTw2ApD/MLwu9fWRkTwOLM7NmaiT5y44rCj8dRMTlEXFqRJza3z+/gcWZWTM10uG3lckXqCwnvyBkOhraR//NmxpYpJkdjA7sq3veRrb8twEr8wtGBsjOTV/XwOuZWYXmvOWPiFFJFwP/SfZT3xURcVfTIjOzlmrod/6IuJbskkoz6zI+w88sUZVe2BNHzGP/i0+pcpFmSYn/uqHueb3lN0uUk98sUU5+s0Q5+c0S5eQ3S1Slvf3Di4It549Oaovx4iUCkkcXMoio/7Mxm3k7SbPjHr6r/ud6y2+WKCe/WaKc/GaJcvKbJaracftDxNjkDo4YKX7/lHUCtpsGxgttU9cFgLK2duor7wBST7E9hjtsW1ASI/3F/wNADPcW2zqsv6/0MzRa/Lw09PmfxXM77L9tZlVx8pslyslvlignv1miGq3Y8wCwGxgDRiPi1IPOPyL6fj657LzGGomgQmVfk2UdSh3WyTStsvUp70vrLNNtrro19ibHrZH6O/ya0dv/mxHxSBNex8wq5N1+s0Q1mvwBXCfp+5L+sGyG2oo9Y3tcscesUzS62//CiNgm6VhgvaR7IuLG2hki4nLgcoDB5Su65YjY7JDX6NDd2/K/OyV9nax4543Tzg/ElBOxVNbh0YFfEVPjhmk6Kzss9rK4gdJ9vtL/RTuV9F2NT7M+PZ0We4my2MtOYqzqMzTn3X5J8yUtnLgPvBxwLS6zLtHIln8p8HVJE6/zxYj4dlOiMrOWa6Rc1/3Ac5sYi5lVyD/1mSWq0kt6RUknWYd1kE2nGzr3ykx7BmUXdJCVvb893XJGaInS2Nv4GfKW3yxRTn6zRDn5zRLl5DdLlJPfLFGV9vYHEFO+bnx6b2tNe3pvyamzGi22tdUhdnpv6WeoG0/vNbPu5uQ3S5ST3yxRTn6zRFVbsacvGF1cR69Sh3WaAaWdTx0Z51SzKf7SaevTaPGjblifJscY/S7RbWYzcPKbJcrJb5YoJ79Zombs8JN0BXAOsDMifjlvWwRcA5wIPAC8PiIem3Fp40J7p47g2Wm9MmZdrMkluq8EzprS9m5gQ0SsBDbkj82si8yY/Pk4/I9OaT4XWJvfXwuc1+S4zKzF5nrMvzQitgPkf4+dbsZJFXuGXLHHrFO0vMMvIi6PiFMj4tTeBfNbvTgzq9Ncz/DbIWlZRGyXtAzYWdezAnpGpjY2ehqXmf2/WfSfz3XLvw5Yk99fA3xjjq9jZm0yY/JLugr4HvAsSVslXQR8CDhT0mbgzPyxmXWRGXf7I+L8aSa9tMmxmFmFfIafWaIqv6R3bHGhx8/MmqXPl/Sa2Qyc/GaJcvKbJcrJb5aoakt0D4uBrQNVLtIsKRpu7iW9ZnYIcvKbJcrJb5YoJ79Zopz8Zoly8pslyslvlignv1minPxmiapnJJ8rJO2UtKmm7RJJP5e0Mb+d3dowzazZ5lq0A+CyiFiV365tblhm1mpzLdphZl2ukWP+iyXdmR8WHN20iMysEnNN/k8CTwdWAduBS6ebcVLFnj2u2GPWKeaU/BGxIyLGImIc+DSw+iDzPlmxZ74r9ph1ijklf16lZ8JrgE3TzWtmnWnGwTzyoh2nA8dI2gr8HXC6pFVkxYEeAP6ohTGaWQvMtWjHZ1sQi5lVyGf4mSXKyW+WKCe/WaKc/GaJcvKbJcrJb5YoJ79Zopz8Zoly8pslyslvlignv1minPxmiXLymyXKyW+WKCe/WaKc/GaJcvKbJaqeij0rJF0v6W5Jd0l6W96+SNJ6SZvzvx6+26yL1LPlHwXeGREnAacBfyLpZO
DdwIaIWAlsyB+bWZeop2LP9oi4I7+/G7gbOB44F1ibz7YWOK9VQZpZ883qmF/SicCvArcASyNiO2RfEMCx0zzHRTvMOlDdyS9pAfBV4M8i4ol6n+eiHWadqa7kl9RPlvhfiIiv5c07Jop35H93tiZEM2uFenr7RTZO/90R8dGaSeuANfn9NcA3mh+embXKjEU7gBcCbwZ+JGlj3vZe4EPAlyRdBPwMeF1rQjSzVqinYs9NgKaZ/NLmhmNmVfEZfmaJcvKbJaqeY/6m0Tj07Z18BDE2LwrzxXQHGW3Ut68Y1PhAcb7x3uL6dIve4fre+LGB7l1He5K3/GaJcvKbJcrJb5YoJ79Zoirt8Ot/aA/L//G/J7Vte9cLCvMNH9neDqWBJ4odX8ddekuhbdebVhfbVrYkpKZTyVu8Yv1Qoa1vZ/EyjvsvPK7Q1s0dnanylt8sUU5+s0Q5+c0S5eQ3S5ST3yxR1Z7ee9hh9J749Mlt41VGUJ+ekWJb39Ilhbbx3gqCaRGNF3/RGD7qsEJb7xNl5zCXvGAXvxep8pbfLFFOfrNEOfnNEtVIxZ5LJP1c0sb8dnbrwzWzZqmnw2+iYs8dkhYC35e0Pp92WUR8pN6F7V/Sx0/eMnl4/94DJaeFtvlM0QOLigFsfttTC20aK3lyGzswy8ZGmHZwhP3FeR84r2TeviMKTQMPdd7/zGavnjH8tgMTxTl2S5qo2GNmXayRij0AF0u6U9IV0xXqdMUes87USMWeTwJPB1aR7RlcWvY8V+wx60xzrtgTETsiYiwixoFPA8XrW82sY814zD9dxR5JyyYKdQKvATbN9Fo9IzBv5+ROpZGFnTeAp8aKAQw+XJxvZGGxrZ3XtZ+y+r5C29ED+0rn/e59xYEHLlv95ULbkt7i9fwXrntroa1vdweOumoH1UjFnvMlrSLr530A+KOWRGhmLdFIxZ5rmx+OmVXFZ/iZJcrJb5aoSi/pBegZnfy47NLSaPNgkBqts63sDL82Xtq6cePTCm1/dWZ55fS3/MZ3C233DC8rtG3at7zQ1lNnZR/rbN7ymyXKyW+WKCe/WaKc/GaJqrbDTzDeP7mp3Wfz1Wtq3NB5sfc/Xvwu/8Tml5TOe8ep1xTaHhrbW2j74I/PKrT1HJhDcNZxvOU3S5ST3yxRTn6zRDn5zRLl5DdLVKW9/dELw0dOOXW3Ayv2jA0WTy8em1cyYxcMWrnnzkWl7c/ce2GhbXS4+HHof7BYxccODd7ymyXKyW+WKCe/WaLqqdgzKOlWST/MK/a8L29/qqRbJG2WdI2kknKuZtap6unwOwCcERFD+Si+N0n6D+AdZBV7rpb0KeAisuG8p6XBMfpPmjwg5MjdxYow7e4EHDmi2JO38ITHC2177j2y0Na7r7PO+e0ZLY/nWct2FNru2XFsoS1wh9+hasYtf2SG8of9+S2AM4Cv5O1rgfNaEqGZtUS94/b35iP37gTWA/cBuyJiYnybrUxTwqu2Ys/oE8ULR8ysPepK/rw4xypgOVlxjpPKZpvmuf9fsafviMPnHqmZNdWsevsjYhdwA3AacJSkiT6D5cC25oZmZq1UT8WeJcBIROySNA94GfBh4HrgtcDVwBqgfKTIGhFibKzzf10s67QbGy+Ju9Mu6C8x3l9+GuJLj7mn0HbvL44ptI00PSLrFPX09i8D1krqJdtT+FJEfFPSj4GrJb0f+AFZSS8z6xL1VOy5k6ws99T2+3FxTrOu1fn74GbWEk5+s0RVeknvQN8oJyx+dFLb/b0LCvOpAy/zXXrE7kLbTw8rxt67v7O+T0fnl3f4/crglkLbnkeKP8X6nO1DV2d9Us2sMk5+s0Q5+c0S5eQ3S1SlHX6j4z3sHJrSSdaBnXtjC4tB9ajYcdYz0gVn+C0oqS0O3HWgeB2W9rWxvrhVzlt+s0Q5+c0S5eQ3S5ST3yxR1RbtGOpj+ObFk9oGy/uj2mrgiWLH17atKwpthw9XEU1j5u0sP0fvW5eeVm
g74szitqCsNLl1Lo3VP6+3/GaJcvKbJcrJb5YoJ79Zohqp2HOlpJ9K2pjfVrU+XDNrlkYq9gD8eUR85SDPnaRnBOZv74K61ocQjZe/30PPOrrQNvho8bTm6On8U5jtST2z+PWsnjH8Aiir2GNmXWxOFXsi4pZ80gck3SnpMkmlRd0mVezZv6dJYZtZo+ZUsUfSLwPvAZ4NPA9YBPzlNM99smLP4PwmhW1mjZprxZ6zImJ7XsTzAPA5PIy3WVeZc8UeScsiYrskkVXo3TTTa4VgzCNCVqy8w2500NfuH4pmU0SqkYo938m/GARsBP54DrGaWZs0UrHnjJZEZGaV8Bl+Zoly8pslqtLr+cf7Ye9SnzFm1iqzGX/BW36zRDn5zRLl5DdLlJPfLFGVdvj1jMDhO3xBoFmr9IzMYt7WhWFmnczJb5YoJ79Zopz8ZomqtMNPMbuKImY2OyWV5KflLb9Zopz8Zoly8pslyslvlqi6kz8fvvsHkr6ZP36qpFskbZZ0jSSPzmfWRWbT2/824G7giPzxh4HLIuJqSZ8CLgI+ebAX6BkOFm6ZUtS+pCLMeG/5Nf8Du4YLbRorVplphfGB4oCXIwuLF0+X9bb2DRXPuewZruZnj+gt/34fPqr4Xd0zVgxeo8W2vqHi/6Eqw0eXlodAJbFHX3HdB3YdKD55mqpGzTY2v/h5GRsoxjhdlaR6Pv+9++vPh3qLdiwHXgl8Jn8s4AxgolTXWrIRfM2sS9S72/8x4C+Aia+VxcCuiJioDLYVOL7sibUVe0ZGXLHHrFPUU6X3HGBnRHy/trlk1tJ9p9qKPf39rthj1inqOeZ/IfBqSWcDg2TH/B8DjpLUl2/9lwPbWhemmTVbPeP2v4esLh+STgfeFREXSPoy8FrgamAN8I2ZXmt0vtixenKHzciC4g7D+EB5B8xxNxU7e/qfqKbjbOj4YgfZL1aVxFnStPjO4tu8YGs1nWYjC8sr82x7cXGnr/dAcYdu4PFi27G3Nx7XXP38JeUjVPbtK8ZZ9tlasaG+Ts1WePTk4ud3aEVJR2VveTxP+d5goe2wxybX5J5NSfVGfuf/S+Adku4l6wP4bAOvZWYVm9WFPRFxA1mhTiLiflyc06xr+Qw/s0Q5+c0SpYjqBtSU9DDwYP7wGOCRyhbeWofSuoDXp9MdbH1OiIgl9bxIpck/acHS7RFxalsW3mSH0rqA16fTNWt9vNtvlignv1mi2pn8l7dx2c12KK0LeH06XVPWp23H/GbWXt7tN0uUk98sUZUnv6SzJP2vpHslvbvq5TdK0hWSdkraVNO2SNL6fEiz9ZKObmeMsyFphaTrJd0t6S5Jb8vbu26dJA1KulXSD/N1eV/e3tVDzrVqCL1Kk19SL/AJ4BXAycD5kk6uMoYmuBI4a0rbu4ENEbES2JA/7hajwDsj4iTgNOBP8v9JN67TAeCMiHgusAo4S9JpPDnk3ErgMbIh57rJxBB6E5qyPlVv+VcD90bE/RExTHY58LkVx9CQiLgReHRK87lkQ5lBlw1pFhHbI+KO/P5usg/Z8XThOkVmKH/Yn9+CLh5yrpVD6FWd/McDW2oeTzv8V5dZGhHbIUsm4Ng2xzMnkk4EfhW4hS5dp3wXeSOwE1gP3EedQ851qDkPoTeTqpO/7uG/rFqSFgBfBf4sIp5odzxzFRFjEbGKbHSp1cBJZbNVG9XcNDqE3kwqLdRJ9i21oubxoTL81w5JyyJiu6RlZFudriGpnyzxvxARX8ubu3qdImKXpBvI+jG6dci5lg6hV/WW/zZgZd5bOQC8AVhXcQytsI5sKDOoc0izTpEfQ34WuDsiPlozqevWSdISSUfl9+cBLyPrw7iebMg56JJ1gWwIvYhYHhEnkuXKdyLiApq1PhFR6Q04G/gJ2bHYX1W9/CbEfxWwHRgh25O5iOw4bAOwOf+7qN1xzmJ9XkS223gnsDG/nd2N6wScAvwgX5dNwN/m7U8DbgXuBb4MHNbuWOewbq
cD32zm+vj0XrNE+Qw/s0Q5+c0S5eQ3S5ST3yxRTn6zRDn5zRLl5DdL1P8B8FPBd33wU/8AAAAASUVORK5CYII=",
            "text/plain": [
              "<matplotlib.figure.Figure at 0x7fc9fb6365c0>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "s, _ = env.reset()\n",
        "for _ in range(100):\n",
        "    s, _, terminated, truncated, _ = env.step(env.action_space.sample())\n",
        "    # Stepping an episode that has already ended is invalid in gymnasium: start a new one.\n",
        "    if terminated or truncated:\n",
        "        s, _ = env.reset()\n",
        "\n",
        "plt.title('Game image')\n",
        "plt.imshow(env.render())\n",
        "plt.show()\n",
        "\n",
        "plt.title('Agent observation')\n",
        "# Observations are (channels, height, width); show the single channel.\n",
        "plt.imshow(s[0])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### POMDP setting\n",
        "\n",
        "The Atari game we're working with is actually a POMDP: the agent needs to know the timing at which enemies spawn and move, but it cannot know that unless it has some memory of past observations.\n",
        "\n",
        "Let's design another agent that has a recurrent neural net memory to solve this. Here's a sketch.\n",
        "\n",
        "![img](img1.jpg)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class SimpleRecurrentAgent(nn.Module):\n",
        "    \"\"\"Convolutional actor-critic agent with an LSTM memory.\n",
        "\n",
        "    One step: three stride-2 convolutions -> flatten -> dense layer -> LSTMCell\n",
        "    -> (policy logits, state value). The recurrent state is an (h, c) pair of\n",
        "    tensors of shape [batch_size, 128].\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, obs_shape, n_actions, reuse=False):\n",
        "        \"\"\"A simple actor-critic agent\n",
        "\n",
        "        :param obs_shape: observation shape; not read by the layers below, which assume\n",
        "            one input channel and a 512-dim flattened conv output (32 * 4 * 4 for 42x42 inputs).\n",
        "        :param n_actions: number of discrete actions (size of the policy head).\n",
        "        :param reuse: unused; kept for backward compatibility.\n",
        "        \"\"\"\n",
        "        # Plain super() stays correct under subclassing, whereas\n",
        "        # super(self.__class__, self) recurses forever in any subclass.\n",
        "        super().__init__()\n",
        "\n",
        "        self.conv0 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2))\n",
        "        self.conv1 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n",
        "        self.conv2 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))\n",
        "        self.flatten = nn.Flatten()\n",
        "\n",
        "        self.hid = nn.Linear(512, 128)\n",
        "        self.rnn = nn.LSTMCell(128, 128)\n",
        "\n",
        "        self.logits = nn.Linear(128, n_actions)\n",
        "        self.state_value = nn.Linear(128, 1)\n",
        "\n",
        "    def forward(self, prev_state, obs_t):\n",
        "        \"\"\"\n",
        "        Takes agent's previous hidden state and a new observation,\n",
        "        returns a new hidden state and whatever the agent needs to learn\n",
        "\n",
        "        :param prev_state: (h, c) tuple, each a [batch_size, 128] tensor.\n",
        "        :param obs_t: float tensor with a batch of observations\n",
        "            (presumably [batch_size, 1, 42, 42] for this environment).\n",
        "        :returns: (new_state, (logits, state_value)). new_state must again be an\n",
        "            (h, c) tuple, because step() unpacks it that way.\n",
        "        \"\"\"\n",
        "\n",
        "        # Apply the whole neural net for one step here.\n",
        "        # See docs on self.rnn(...): nn.LSTMCell takes (input, (h, c)) and returns the new (h, c).\n",
        "        # The recurrent cell should take the last feedforward dense layer as input.\n",
        "        <YOUR CODE>\n",
        "\n",
        "        new_state = <YOUR CODE>\n",
        "        logits = <YOUR CODE>\n",
        "        state_value = <YOUR CODE>\n",
        "\n",
        "        return new_state, (logits, state_value)\n",
        "\n",
        "    def get_initial_state(self, batch_size):\n",
        "        \"\"\"Return the agent memory at game start: an (h, c) tuple of zero tensors, each of shape [batch_size, 128].\"\"\"\n",
        "        return torch.zeros((batch_size, 128)), torch.zeros((batch_size, 128))\n",
        "\n",
        "    def sample_actions(self, agent_outputs):\n",
        "        \"\"\"Sample one action per batch row from the policy.\n",
        "\n",
        "        :param agent_outputs: (logits, state_values) tensors, as returned by step().\n",
        "        :returns: numpy integer array of shape [batch_size] with the sampled actions.\n",
        "        \"\"\"\n",
        "        logits, _ = agent_outputs\n",
        "        probs = F.softmax(logits, dim=-1)\n",
        "        return torch.multinomial(probs, 1)[:, 0].detach().cpu().numpy()\n",
        "\n",
        "    def step(self, prev_state, obs_t):\n",
        "        \"\"\"Like forward, but obs_t is array-like and every returned tensor is detached from the graph.\"\"\"\n",
        "        obs_t = torch.tensor(np.asarray(obs_t), dtype=torch.float32)\n",
        "        (h, c), (l, s) = self.forward(prev_state, obs_t)\n",
        "        return (h.detach(), c.detach()), (l.detach(), s.detach())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Number of environments handed to EnvPool below.\n",
        "n_parallel_games = 5\n",
        "# Reward discount factor.\n",
        "gamma = 0.99\n",
        "\n",
        "agent = SimpleRecurrentAgent(obs_shape, n_actions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "first_obs, _ = env.reset()\n",
        "state = [first_obs]  # a batch containing a single observation\n",
        "initial_memory = agent.get_initial_state(1)\n",
        "_, (logits, value) = agent.step(initial_memory, state)\n",
        "print(\"action logits:\\n\", logits)\n",
        "print(\"state values:\\n\", value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Let's play!\n",
        "Let's build a function that plays full games and returns the agent's total reward in each of them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def evaluate(agent, env, n_games=1):\n",
        "    \"\"\"Play `n_games` complete episodes and return the list of per-episode total rewards.\"\"\"\n",
        "    game_rewards = []\n",
        "    for _ in range(n_games):\n",
        "        # Every episode starts from a fresh observation and the agent's initial memory.\n",
        "        observation, _ = env.reset()\n",
        "        memory = agent.get_initial_state(1)\n",
        "\n",
        "        total_reward = 0\n",
        "        done = False\n",
        "        while not done:\n",
        "            # observation[None, ...] adds a batch axis of size 1.\n",
        "            memory, readouts = agent.step(memory, observation[None, ...])\n",
        "            action = agent.sample_actions(readouts)[0]\n",
        "\n",
        "            observation, reward, terminated, truncated, _ = env.step(action)\n",
        "            total_reward += reward\n",
        "            done = terminated or truncated\n",
        "\n",
        "        game_rewards.append(total_reward)\n",
        "    return game_rewards"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from gymnasium.wrappers import RecordVideo\n",
        "\n",
        "# RecordVideo writes its video files into ./videos; both envs are closed on exit.\n",
        "with make_env() as record_env:\n",
        "    with RecordVideo(record_env, video_folder=\"videos\") as env_monitor:\n",
        "        rewards = evaluate(agent, env_monitor, n_games=3)\n",
        "\n",
        "print(rewards)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Show video. This may not work in some setups. If it doesn't\n",
        "# work for you, you can download the videos and view them locally.\n",
        "\n",
        "from pathlib import Path\n",
        "from base64 import b64encode\n",
        "from IPython.display import HTML\n",
        "\n",
        "video_paths = sorted(path for path in Path('videos').iterdir() if path.suffix == '.mp4')\n",
        "video_path = video_paths[-1]  # You can also try other indices\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    # On Colab the video is embedded inline as a base64 data URL.\n",
        "    # https://stackoverflow.com/a/57378660/1214547\n",
        "    mp4 = video_path.read_bytes()\n",
        "    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n",
        "else:\n",
        "    data_url = str(video_path)\n",
        "\n",
        "HTML(\"\"\"\n",
        "<video width=\"640\" height=\"480\" controls>\n",
        "  <source src=\"{}\" type=\"video/mp4\">\n",
        "</video>\n",
        "\"\"\".format(data_url))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training on parallel games\n",
        "\n",
        "We introduce a class called EnvPool - it's a tool that handles multiple environments for you. Here's how it works:\n",
        "![img](img2.jpg)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from env_pool import EnvPool\n",
        "\n",
        "# EnvPool receives the agent, the environment factory and the number of parallel games.\n",
        "pool = EnvPool(agent, make_env, n_parallel_games)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We're going to train our agent on __rollouts:__\n",
        "![img](img3.jpg)\n",
        "\n",
        "A rollout is just a sequence of T observations, actions and rewards that the agent collected consecutively.\n",
        "* The first state __s0__ is not necessarily the initial state of the environment\n",
        "* The final state is not necessarily terminal\n",
        "* We sample several parallel rollouts for efficiency"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Advance each of the n_parallel_games environments by 10 steps\n",
        "# and collect the resulting rollout arrays.\n",
        "(rollout_obs, rollout_actions,\n",
        " rollout_rewards, rollout_mask) = pool.interact(10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "for label, array in ((\"Actions shape:\", rollout_actions),\n",
        "                     (\"Rewards shape:\", rollout_rewards),\n",
        "                     (\"Mask shape:\", rollout_mask),\n",
        "                     (\"Observations shape: \", rollout_obs)):\n",
        "    print(label, array.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Actor-critic objective\n",
        "\n",
        "Here we define a loss function that uses the rollout above to train an advantage actor-critic agent.\n",
        "\n",
        "\n",
        "Our loss consists of three components:\n",
        "\n",
        "* __The policy \"loss\"__\n",
        " $$ \\hat J = {1 \\over T} \\cdot \\sum_t { \\log \\pi(a_t | s_t) } \\cdot A_{const}(s_t,a_t) $$\n",
        "  * This function has no meaning in and of itself, but it was built such that\n",
        "  * $ \\nabla \\hat J = {1 \\over T} \\cdot \\sum_t { \\nabla \\log \\pi(a_t | s_t) } \\cdot A(s_t,a_t) \\approx \\nabla E_{s, a \\sim \\pi} R(s,a) $\n",
        "  * Therefore if we __maximize__ $\\hat J$ with gradient ascent, we will maximize the expected reward\n",
        "  \n",
        "  \n",
        "* __The value \"loss\"__\n",
        "  $$ L_{td} = {1 \\over T} \\cdot \\sum_t { [r_t + \\gamma \\cdot V_{const}(s_{t+1}) - V(s_t)] ^ 2 }$$\n",
        "  * Ye Olde TD_loss from q-learning and alike\n",
        "  * If we minimize this loss, V(s) will converge to $V_\\pi(s) = E_{a \\sim \\pi(a | s)} R(s,a) $\n",
        "\n",
        "\n",
        "* __Entropy Regularizer__\n",
        "  $$ H = - {1 \\over T} \\sum_t \\sum_a {\\pi(a|s_t) \\cdot \\log \\pi (a|s_t)}$$\n",
        "  * If we __maximize__ entropy we discourage the agent from assigning zero probability to actions\n",
        "  prematurely (a.k.a. exploration)\n",
        "  \n",
        "  \n",
        "So we minimize a linear combination of $L_{td}$, $-\\hat J$ and $-H$\n",
        "  \n",
        "```\n",
        "\n",
        "```\n",
        "\n",
        "```\n",
        "\n",
        "```\n",
        "\n",
        "```\n",
        "\n",
        "```\n",
        "\n",
        "\n",
        "__One more thing:__ since we train on T-step rollouts, we can use N-step formula for advantage for free:\n",
        "  * At the last step, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot V(s_{t+1}) - V(s_t) $\n",
        "  * One step earlier, $A(s_t,a_t) = r(s_t, a_t) + \\gamma \\cdot r(s_{t+1}, a_{t+1}) + \\gamma ^ 2 \\cdot V(s_{t+2}) - V(s_t) $\n",
        "  * Et cetera, et cetera. This way the agent starts training much faster, since its estimate of A(s,a) depends less on its (imperfect) value function and more on actual rewards. There's also a [nice generalization](https://arxiv.org/abs/1506.02438) of this.\n",
        "\n",
        "\n",
        "__Note:__ it's also a good idea to scale rollout_len up to learn longer sequences. You may wish set it to >=20 or to start at 10 and then scale up as time passes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "opt = torch.optim.Adam(agent.parameters(), lr=1e-5)\n",
        "\n",
        "\n",
        "def train_on_rollout(states, actions, rewards, is_not_done, prev_memory_states, gamma=0.99):\n",
        "    \"\"\"\n",
        "    Takes a sequence of states, actions and rewards produced by generate_session.\n",
        "    Updates agent's weights by following the policy gradient above.\n",
        "    Please use Adam optimizer with default parameters.\n",
        "    \"\"\"\n",
        "\n",
        "    # shape: [batch_size, time, c, h, w]\n",
        "    states = torch.tensor(np.asarray(states), dtype=torch.float32)\n",
        "    actions = torch.tensor(np.array(actions), dtype=torch.int64)  # shape: [batch_size, time]\n",
        "    rewards = torch.tensor(np.array(rewards), dtype=torch.float32)  # shape: [batch_size, time]\n",
        "    is_not_done = torch.tensor(np.array(is_not_done), dtype=torch.float32)  # shape: [batch_size, time]\n",
        "    rollout_length = rewards.shape[1] - 1\n",
        "\n",
        "    # predict logits, probas and log-probas using an agent.\n",
        "    memory = [m.detach() for m in prev_memory_states]\n",
        "\n",
        "    logits = []  # append logit sequence here\n",
        "    state_values = []  # append state values here\n",
        "    for t in range(rewards.shape[1]):\n",
        "        obs_t = states[:, t]\n",
        "\n",
        "        # use agent to comute logits_t and state values_t.\n",
        "        # append them to logits and state_values array\n",
        "\n",
        "        memory, (logits_t, values_t) = <YOUR CODE>\n",
        "\n",
        "        logits.append(logits_t)\n",
        "        state_values.append(values_t)\n",
        "\n",
        "    logits = torch.stack(logits, dim=1)\n",
        "    state_values = torch.stack(state_values, dim=1)\n",
        "    probas = F.softmax(logits, dim=2)\n",
        "    logprobas = F.log_softmax(logits, dim=2)\n",
        "\n",
        "    # select log-probabilities for chosen actions, log pi(a_i|s_i)\n",
        "    actions_one_hot = F.one_hot(actions, n_actions).view(\n",
        "        actions.shape[0], actions.shape[1], n_actions)\n",
        "    logprobas_for_actions = torch.sum(logprobas * actions_one_hot, dim=-1)\n",
        "\n",
        "    # Now let's compute two loss components:\n",
        "    # 1) Policy gradient objective.\n",
        "    # Notes: Please don't forget to call .detach() on advantage term. Also please use mean, not sum.\n",
        "    # it's okay to use loops if you want\n",
        "    J_hat = 0  # policy objective as in the formula for J_hat\n",
        "\n",
        "    # 2) Temporal difference MSE for state values\n",
        "    # Notes: Please don't forget to call .detach() on V(s') term. Also please use mean, not sum.\n",
        "    # it's okay to use loops if you want\n",
        "    value_loss = 0\n",
        "\n",
        "    cumulative_returns = state_values[:, -1].detach()\n",
        "\n",
        "    for t in reversed(range(rollout_length)):\n",
        "        r_t = rewards[:, t]                                # current rewards\n",
        "        # current state values\n",
        "        V_t = state_values[:, t]\n",
        "        V_next = state_values[:, t + 1].detach()           # next state values\n",
        "        # log-probability of a_t in s_t\n",
        "        logpi_a_s_t = logprobas_for_actions[:, t]\n",
        "\n",
        "        # update G_t = r_t + gamma * G_{t+1} as we did in week6 reinforce\n",
        "        cumulative_returns = r_t + gamma * cumulative_returns\n",
        "\n",
        "        # Compute temporal difference error (MSE for V(s))\n",
        "        value_loss += <YOUR CODE>\n",
        "\n",
        "        # compute advantage A(s_t, a_t) using cumulative returns and V(s_t) as baseline\n",
        "        advantage = <YOUR CODE>\n",
        "        advantage = advantage.detach()\n",
        "\n",
        "        # compute policy pseudo-loss aka -J_hat.\n",
        "        J_hat += <YOUR CODE>\n",
        "\n",
        "    # regularize with entropy\n",
        "    entropy_reg = <YOUR CODE: compute entropy regularizer>\n",
        "\n",
        "    # add-up three loss components and average over time\n",
        "    loss = -J_hat / rollout_length +\\\n",
        "        value_loss / rollout_length +\\\n",
        "           -0.01 * entropy_reg\n",
        "\n",
        "    # Gradient descent step\n",
        "    <YOUR CODE>\n",
        "\n",
        "    return loss.data.numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# let's test it\n",
        "memory = list(pool.prev_memory_states)\n",
        "rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(10)\n",
        "\n",
        "train_on_rollout(rollout_obs, rollout_actions,\n",
        "                 rollout_rewards, rollout_mask, memory)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Train \n",
        "\n",
        "just run train step and see if agent learns any better"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from IPython.display import clear_output\n",
        "from tqdm import trange\n",
        "from pandas import DataFrame\n",
        "moving_average = lambda x, **kw: DataFrame(\n",
        "    {'x': np.asarray(x)}).x.ewm(**kw).mean().values\n",
        "\n",
        "rewards_history = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "for i in trange(15000):\n",
        "\n",
        "    memory = list(pool.prev_memory_states)\n",
        "    rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(\n",
        "        10)\n",
        "    train_on_rollout(rollout_obs, rollout_actions,\n",
        "                     rollout_rewards, rollout_mask, memory)\n",
        "\n",
        "    if i % 100 == 0:\n",
        "        rewards_history.append(np.mean(evaluate(agent, env, n_games=1)))\n",
        "        clear_output(True)\n",
        "        plt.plot(rewards_history, label='rewards')\n",
        "        plt.plot(moving_average(np.array(rewards_history),\n",
        "                                span=10), label='rewards ewma@10')\n",
        "        plt.legend()\n",
        "        plt.show()\n",
        "        if rewards_history[-1] >= 10000:\n",
        "            print(\"Your agent has just passed the minimum homework threshold\")\n",
        "            break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Relax and grab some refreshments while your agent is locked in an infinite loop of violence and death.\n",
        "\n",
        "__How to interpret plots:__\n",
        "\n",
        "The session reward is the easy thing: it should in general go up over time, but it's okay if it fluctuates ~~like crazy~~. It's also OK if it reward doesn't increase substantially before some 10k initial steps. However, if reward reaches zero and doesn't seem to get up over 2-3 evaluations, there's something wrong happening.\n",
        "\n",
        "\n",
        "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n",
        "\n",
        "If it does, the culprit is likely:\n",
        "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot log p(a_i) $\n",
        "* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n",
        "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/56069467) and maybe use a smaller network\n",
        "* Us. Or PyTorch developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n",
        "\n",
        "If you're debugging, just run `logits, values = agent.step(batch_states)` and manually look into logits and values. This will reveal the problem 9 times out of 10: you'll likely see some NaNs or insanely large numbers or zeros. Try to catch the moment when this happens for the first time and investigate from there."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### \"Final\" evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from gymnasium.wrappers import RecordVideo\n",
        "\n",
        "with make_env() as record_env, RecordVideo(record_env, video_folder=\"videos\") as env_monitor:\n",
        "    final_rewards = evaluate(agent, env_monitor, n_games=20)\n",
        "\n",
        "print(\"Final mean reward\", np.mean(final_rewards))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Show video. This may not work in some setups. If it doesn't\n",
        "# work for you, you can download the videos and view them locally.\n",
        "\n",
        "from pathlib import Path\n",
        "from base64 import b64encode\n",
        "from IPython.display import HTML\n",
        "\n",
        "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n",
        "video_path = video_paths[-1]  # You can also try other indices\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    # https://stackoverflow.com/a/57378660/1214547\n",
        "    with video_path.open('rb') as fp:\n",
        "        mp4 = fp.read()\n",
        "    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n",
        "else:\n",
        "    data_url = str(video_path)\n",
        "\n",
        "HTML(\"\"\"\n",
        "<video width=\"640\" height=\"480\" controls>\n",
        "  <source src=\"{}\" type=\"video/mp4\">\n",
        "</video>\n",
        "\"\"\".format(data_url))"
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 1
}
